from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore import Tensor
import numpy as np
from mindspore import train
"""Training while evaluating
 Args:
     network_model: network model
     training_data: training sets.
     batch_size:  Integer. Number of samples per gradient update.
     accuracy: float. Target accuracy of the model
     earlyStopEpoch: Interger or None.
     validation_data: validation sets.
     validation_split: Fraction of the training data to be used as validation data.
     validation_data: validation sets.
     prefix: ModelCheckpoint.prefix
     epochs: Integer. Number of epochs to train the model.
     callbacks_cb: List of mindspore.train.callback instances. List of callbacks to apply during training.
     verbose: 'auto', 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. 
     shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch').
     class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only).
     sample_weight:  Optional Numpy array of weights for the training samples, used for weighting the loss function (during training only).
     initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run).
     steps_per_epoch: Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch.
     sink_mode: sink mode(True or False)
 """
def fit(network_model,
        training_data,
        batch_size,
        accuracy,
        earlyStopEpoch=None,
        validation_data=None,
        validation_split=0.0,
        prefix="tmp",
        epochs=1,
        callbacks_cb=None,
        verbose="auto",
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        sink_mode=False
        ):
    """Train ``network_model`` while evaluating it during training.

    Installs a ModelCheckpoint and a FitMonitor (which stops training once
    ``accuracy`` is reached on ``validation_data``); adds an EarlyStopping
    callback when ``earlyStopEpoch`` is given, then runs ``Model.train``.

    Args:
        network_model: mindspore ``train.Model`` to train.
        training_data: training dataset.
        batch_size: Integer. Number of samples per gradient update.
        accuracy (float): target accuracy; training stops once reached.
        earlyStopEpoch: Integer or None. Stop training after this many epochs.
        validation_data: validation dataset evaluated by FitMonitor.
        validation_split: fraction of training data to use for validation
            (splitting is not implemented yet — see TODO below).
        prefix: checkpoint file name prefix.
        epochs: Integer. Number of epochs to train.
        callbacks_cb: extra mindspore callbacks applied during training.
        verbose, shuffle, class_weight, sample_weight, initial_epoch,
            steps_per_epoch: accepted for a Keras-like signature but
            currently unused by this implementation.
        sink_mode: bool. Dataset sink mode passed to ``Model.train``.
    """
    # BUG FIX: the original guard `not isinstance(network_model, type(train.Model))`
    # compared against the metaclass (`type`) and therefore never checked the
    # model type at all; a plain None-guard expresses the actual intent.
    if network_model is None or training_data is None or batch_size is None:
        return

    # Save a checkpoint every step, keeping at most the 10 most recent files.
    config_ck = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=10)
    # BUG FIX: honour the caller-supplied `prefix` instead of hard-coding "tmp".
    ckpoint = ModelCheckpoint(prefix=prefix, directory=r"./tmp", config=config_ck)

    # TODO: when 0.0 < validation_split < 1.0, carve validation_data out of
    # training_data; currently only the explicit validation_data argument is
    # used (the original placeholder branches only printed blank lines).

    # FitMonitor evaluates on validation_data and requests a stop once the
    # target accuracy is reached; built-in callbacks run before user callbacks.
    fit_monitor = FitMonitor(monitor_model=network_model,
                             validation_data=validation_data,
                             monitor_accuracy=accuracy,
                             dataset_sink_mode=False)
    callbacks_cb_tmp = [ckpoint, fit_monitor]
    if callbacks_cb is not None:
        callbacks_cb_tmp += callbacks_cb
    if earlyStopEpoch is not None:
        callbacks_cb_tmp.append(EarlyStopping(patience=earlyStopEpoch))

    # Start training; bool() replaces the verbose `== False` if/else ladder.
    network_model.train(epochs, training_data,
                        callbacks=callbacks_cb_tmp,
                        dataset_sink_mode=bool(sink_mode))


class FitMonitor(Callback):
    """
    Evaluate model accuracy on a validation set during training.

    Every ``per_print_times`` steps the model is evaluated on
    ``validation_data`` and the loss/accuracy are printed; if the measured
    accuracy reaches ``monitor_accuracy``, training is stopped via
    :func:`checkAccuracy`.

    Note:
        If per_print_times is 0, loss is never printed and no evaluation runs.

    Args:
        monitor_model: the mindspore Model being trained (used for ``eval``).
        validation_data: dataset used to measure accuracy; must not be None.
        monitor_accuracy (float): target accuracy that stops training.
        dataset_sink_mode (bool): forwarded to ``Model.eval``. Default: False.
        per_print_times (int): evaluate and print every N steps. Default: 1.

    Raises:
        ValueError: if per_print_times is not a non-negative integer, if
            validation_data is None, or if monitor_accuracy is not a float.
    """
    def __init__(self, monitor_model, validation_data, monitor_accuracy, dataset_sink_mode=False, per_print_times=1, ):
        super(FitMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        # Fail fast on a bad configuration before storing anything.
        # (Bug fix: the original had unreachable `return` statements after
        # these raises.)
        if validation_data is None:
            raise ValueError("dataset valid")
        if not isinstance(monitor_accuracy, float):
            raise ValueError("accuracy valid")
        self._per_print_times = per_print_times
        self.monitor_model = monitor_model
        self.validation_data = validation_data
        self.monitor_accuracy = monitor_accuracy
        self.dataset_sink_mode = dataset_sink_mode

    def step_end(self, run_context):
        """
        Called after each step finished.

        Extracts a scalar loss from the step outputs, aborts on NaN/Inf loss,
        and every ``per_print_times`` steps evaluates the model on the
        validation set, printing the result and stopping training once the
        target accuracy is reached.

        Args:
            run_context (RunContext): Include some information of the model.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs
        # Training may return (loss, ...) tuples; take the first element.
        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        # Reduce a Tensor loss to a Python float for printing/NaN checks.
        if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        # 1-based step index inside the current epoch.
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            # (Bug fix: removed the dead `x = 1; if x:` flag — the evaluation
            # branch was always taken and the plain-loss branch was unreachable.)
            Accuracy = self.monitor_model.eval(self.validation_data, dataset_sink_mode=self.dataset_sink_mode)
            print("epoch: %s step: %s, loss is %s, accuracy is %f" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss, Accuracy['Accuracy']), flush=True)
            # Bug fix: pass run_context (which exposes request_stop) instead of
            # cb_params — request_stop is a RunContext method, so the original
            # call would fail with AttributeError when the target was reached.
            checkAccuracy(params=run_context, set_accuracy=self.monitor_accuracy, cur_accuracy=Accuracy['Accuracy'])

def checkAccuracy(params, set_accuracy, cur_accuracy):
    """Request a training stop once the measured accuracy reaches the target.

    Args:
        params: object exposing ``request_stop()`` (the training run context).
        set_accuracy: target accuracy threshold.
        cur_accuracy: accuracy measured on the validation set.
    """
    target_reached = set_accuracy <= cur_accuracy
    if target_reached:
        params.request_stop()


class EarlyStopping(Callback):
    """
    Stop the training after the number of patience of epochs.

    Training is stopped at the beginning of the first epoch whose number
    exceeds ``patience``.

    Args:
        patience (int or None): maximum epoch number to train; None disables
            the callback.
    """
    def __init__(self, patience):
        super(EarlyStopping, self).__init__()
        self._patience = patience

    def epoch_begin(self, run_context):
        """Request a stop once the epoch counter passes the patience limit.

        Bug fix: the original called ``cb_params.request_top()`` — misspelled
        AND on the wrong object. ``request_stop`` is a method of RunContext,
        not of the callback params.
        """
        if self._patience is None:
            return
        cb_params = run_context.original_args()
        if cb_params.cur_epoch_num > self._patience:
            run_context.request_stop()
